# Adapted from openai/api_server.py and tgi-kai-bridge

import argparse
import asyncio
import json
import os
from http import HTTPStatus
from typing import AsyncGenerator, List, Optional, Tuple

import fastapi
import uvicorn
from fastapi import APIRouter, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from loguru import logger
from prometheus_client import make_asgi_app

from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.kobold.protocol import KAIGenerationInputSchema
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.transformers_utils.tokenizer import get_tokenizer

TIMEOUT_KEEP_ALIVE = 5  # seconds

served_model: str = "Read Only"
engine: Optional[AsyncAphrodite] = None
gen_cache: dict = {}
app = fastapi.FastAPI()

badwordsids: List[int] = []

# Add prometheus asgi middleware to route /metrics/ requests
metrics_app = make_asgi_app()
app.mount("/metrics/", metrics_app)


def _set_badwords(tokenizer, hf_config):  # pylint: disable=redefined-outer-name
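    """Populate the global badwordsids list: prefer the model config's
    bad_words_ids; otherwise ban every vocab entry containing a square
    bracket (excluding the pad token) and append the EOS token id."""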
    global badwordsids
    if hf_config.bad_words_ids is not None:
        badwordsids = hf_config.bad_words_ids
        return

    badwordsids = [
        v for k, v in tokenizer.get_vocab().items()
        if any(c in str(k) for c in "[]")
    ]
    if tokenizer.pad_token_id in badwordsids:
        badwordsids.remove(tokenizer.pad_token_id)
    badwordsids.append(tokenizer.eos_token_id)


kai_api = APIRouter()
extra_api = APIRouter()
kobold_lite_ui = ""

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse({
        "msg": message,
        "type": "invalid_request_error"
    },
                        status_code=status_code.value)


@app.exception_handler(ValueError)
def validation_exception_handler(request, exc):  # pylint: disable=unused-argument
    return create_error_response(HTTPStatus.UNPROCESSABLE_ENTITY, str(exc))


def prepare_engine_payload(
        kai_payload: KAIGenerationInputSchema
) -> Tuple[SamplingParams, List[int]]:
    """Create SamplingParams and truncated input tokens for AsyncEngine"""

    if not kai_payload.genkey:
        kai_payload.genkey = f"kai-{random_uuid()}"

    if kai_payload.max_context_length > max_model_len:
        raise ValueError(
            f"max_context_length ({kai_payload.max_context_length}) "
            "must be less than or equal to "
            f"max_model_len ({max_model_len})")

    # KAIspec: top_k == 0 means disabled, aphrodite: top_k == -1 means disabled
    # https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings
    kai_payload.top_k = kai_payload.top_k if kai_payload.top_k != 0 else -1
    kai_payload.tfs = max(_SAMPLING_EPS, kai_payload.tfs)
    if kai_payload.temperature < _SAMPLING_EPS:
        # temp < _SAMPLING_EPS: greedy sampling
        kai_payload.n = 1
        kai_payload.top_p = 1.0
        kai_payload.top_k = -1

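    # KAI dynatemp widens temperature into a [temp - range, temp + range] band.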
    dynatemp_min = 0.0
    dynatemp_max = 0.0
    if kai_payload.dynatemp_range and kai_payload.dynatemp_range > 0:
        dynatemp_min = kai_payload.temperature - kai_payload.dynatemp_range
        dynatemp_max = kai_payload.temperature + kai_payload.dynatemp_range

    sampling_params = SamplingParams(
        n=kai_payload.n,
        best_of=kai_payload.n,
        repetition_penalty=kai_payload.rep_pen,
        temperature=kai_payload.temperature,
        dynatemp_min=dynatemp_min,
        dynatemp_max=dynatemp_max,
        dynatemp_exponent=kai_payload.dynatemp_exponent,
        smoothing_factor=kai_payload.smoothing_factor,
        smoothing_curve=kai_payload.smoothing_curve,
        tfs=kai_payload.tfs,
        top_p=kai_payload.top_p,
        top_k=kai_payload.top_k,
        top_a=kai_payload.top_a,
        min_p=kai_payload.min_p,
        typical_p=kai_payload.typical,
        eta_cutoff=kai_payload.eta_cutoff,
        epsilon_cutoff=kai_payload.eps_cutoff,
        mirostat_mode=kai_payload.mirostat,
        mirostat_tau=kai_payload.mirostat_tau,
        mirostat_eta=kai_payload.mirostat_eta,
        stop=kai_payload.stop_sequence,
        include_stop_str_in_output=kai_payload.include_stop_str_in_output,
        custom_token_bans=badwordsids
        if kai_payload.use_default_badwordsids else [],
        max_tokens=kai_payload.max_length,
        seed=kai_payload.sampler_seed,
    )

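    # Left-truncate the prompt so prompt plus generation fits the context.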
    max_input_tokens = max(
        1, kai_payload.max_context_length - kai_payload.max_length)
    input_tokens = tokenizer(kai_payload.prompt).input_ids[-max_input_tokens:]

    return sampling_params, input_tokens


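# A hypothetical client sketch for the endpoint below (default port 2242 from
# the CLI args; assumes the remaining schema fields carry defaults):
#
#   import requests
#   resp = requests.post("http://localhost:2242/api/v1/generate",
#                        json={"prompt": "Hello", "max_length": 64,
#                              "max_context_length": 1024})
#   print(resp.json()["results"][0]["text"])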
@kai_api.post("/generate")
async def generate(kai_payload: KAIGenerationInputSchema) -> JSONResponse:
    """Generate text"""

    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    result_generator = engine.generate(None, sampling_params,
                                       kai_payload.genkey, input_tokens)

    final_res: Optional[RequestOutput] = None
    previous_output = ""
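    # Cache partial output so /api/extra/generate/check can poll by genkey.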
    async for res in result_generator:
        final_res = res
        new_chunk = res.outputs[0].text[len(previous_output):]
        previous_output += new_chunk
        gen_cache[kai_payload.genkey] = previous_output

    assert final_res is not None
    del gen_cache[kai_payload.genkey]

    return JSONResponse(
        {"results": [{
            "text": output.text
        } for output in final_res.outputs]})


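# The stream below frames each new chunk as an SSE "message" event whose data
# payload is a JSON object of the form {"token": "<newly generated text>"}.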
@extra_api.post("/generate/stream")
async def generate_stream(
        kai_payload: KAIGenerationInputSchema) -> StreamingResponse:
    """Generate text SSE streaming"""

    sampling_params, input_tokens = prepare_engine_payload(kai_payload)
    results_generator = engine.generate(None, sampling_params,
                                        kai_payload.genkey, input_tokens)

    async def stream_kobold() -> AsyncGenerator[bytes, None]:
        previous_output = ""
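        # Yield only the suffix generated since the previous iteration.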
        async for res in results_generator:
            new_chunk = res.outputs[0].text[len(previous_output):]
            previous_output += new_chunk
            yield b"event: message\n"
            yield f"data: {json.dumps({'token': new_chunk})}\n\n".encode()

    return StreamingResponse(stream_kobold(),
                             headers={
                                 "Cache-Control": "no-cache",
                                 "Connection": "keep-alive"
                             },
                             media_type="text/event-stream")


@extra_api.post("/generate/check")
@extra_api.get("/generate/check")
async def check_generation(request: Request):
    """Check outputs in progress (poll streaming)"""

    text = ""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict and request_dict["genkey"] in gen_cache:
            text = gen_cache[request_dict["genkey"]]
    except json.JSONDecodeError:
        pass

    return JSONResponse({"results": [{"text": text}]})


@extra_api.post("/abort")
async def abort_generation(request: Request):
    """Abort running generation"""
    try:
        request_dict = await request.json()
        if "genkey" in request_dict:
            await engine.abort(request_dict["genkey"])
    except json.JSONDecodeError:
        pass

    return JSONResponse({})


@extra_api.post("/tokencount")
async def count_tokens(request: Request):
    """Tokenize string and return token count"""

    request_dict = await request.json()
    tokenizer_result = tokenizer(request_dict.get("prompt", ""))
    return JSONResponse({"value": len(tokenizer_result.input_ids)})


@kai_api.get("/info/version")
async def get_version():
    """Impersonate KAI"""
    return JSONResponse({"result": "1.2.4"})


@kai_api.get("/model")
async def get_model():
    """Get current model"""
    return JSONResponse({"result": f"aphrodite/{served_model}"})


@kai_api.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """Stub for compatibility"""
    return JSONResponse({"values": []})


@kai_api.get("/config/soft_prompt")
async def get_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({"value": ""})


@kai_api.put("/config/soft_prompt")
async def set_current_softprompt():
    """Stub for compatibility"""
    return JSONResponse({})


@kai_api.get("/config/max_length")
async def get_max_length() -> JSONResponse:
    """Return the configured max output length"""
    max_length = args.max_length
    return JSONResponse({"value": max_length})


@kai_api.get("/config/max_context_length")
@extra_api.get("/true_max_context_length")
async def get_max_context_length() -> JSONResponse:
    """Return the max context length based on the EngineArgs configuration."""
    max_context_length = engine_model_config.max_model_len
    return JSONResponse({"value": max_context_length})


@extra_api.get("/preloadstory")
async def get_preloaded_story() -> JSONResponse:
    """Stub for compatibility"""
    return JSONResponse({})


@extra_api.get("/version")
async def get_extra_version():
    """Impersonate KoboldCpp"""
    return JSONResponse({"result": "KoboldCpp", "version": "1.55.1"})


@app.get("/")
async def get_kobold_lite_ui():
    """Serves a cached copy of the Kobold Lite UI, loading it from disk on
    demand if needed."""
    # read and return embedded kobold lite
    global kobold_lite_ui
    if kobold_lite_ui == "":
        scriptpath = os.path.dirname(os.path.abspath(__file__))
        klitepath = os.path.join(scriptpath, "klite.embd")
        if os.path.exists(klitepath):
            with open(klitepath, "r", encoding="utf-8") as f:
                kobold_lite_ui = f.read()
        else:
            logger.warning(f"Embedded Kobold Lite not found at {klitepath}")
    return HTMLResponse(content=kobold_lite_ui)


@app.get("/health")
async def health() -> Response:
    """Health check route for K8s"""
    return Response(status_code=200)


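# /api/latest is an alias of /api/v1, hidden from the OpenAPI schema.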
app.include_router(kai_api, prefix="/api/v1")
app.include_router(kai_api, prefix="/api/latest", include_in_schema=False)
app.include_router(extra_api, prefix="/api/extra")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Aphrodite KoboldAI-Compatible RESTful API server.")
    parser.add_argument("--host",
                        type=str,
                        default="localhost",
                        help="host name")
    parser.add_argument("--port", type=int, default=2242, help="port number")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument("--max-length",
                        type=int,
                        default=256,
                        help="The maximum length of the generated text. "
                        "For use with Kobold Horde.")

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    logger.debug(f"args: {args}")
    logger.warning("The standalone Kobold API is deprecated and will not "
                   "receive updates. Please use the OpenAI API with the "
                   "--launch-kobold-api flag instead.")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncAphrodite.from_engine_args(engine_args)
    engine_model_config = asyncio.run(engine.get_model_config())
    max_model_len = engine_model_config.max_model_len
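    # prepare_engine_payload reads max_model_len as a module-level global.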

    # A separate tokenizer instance for encoding prompts and counting tokens.
    tokenizer = get_tokenizer(engine_args.tokenizer,
                              tokenizer_mode=engine_args.tokenizer_mode,
                              trust_remote_code=engine_args.trust_remote_code)

    _set_badwords(tokenizer, engine_model_config.hf_config)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
